{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cell构建及其子类"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 概述\n",
    "MindSpore的`Cell`类是构建所有网络的基类，也是网络的基本单元。当用户需要自定义网络时，需要继承`Cell`类，并重写`__init__`方法和`construct`方法。\n",
    "\n",
    "损失函数、优化器和模型层等本质上也属于网络结构，也需要继承`Cell`类才能实现功能，同样用户也可以根据业务需求自定义这部分内容。\n",
    "\n",
    "本节内容首先将会介绍`Cell`类的关键成员函数，然后介绍基于`Cell`实现的MindSpore内置损失函数、优化器和模型层及使用方法，最后通过实例介绍如何利用`Cell`类构建自定义网络。\n",
    "\n",
    "> 本文档适用于GPU和Ascend环境。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 关键成员函数\n",
    "### construct方法\n",
    "`Cell`类重写了`__call__`方法，在`Cell`类的实例被调用时，会执行`construct`方法。网络结构在`construct`方法里面定义。\n",
    "\n",
    "下面的样例中，我们构建了一个简单的网络实现卷积计算功能。构成网络的算子在`__init__`中定义，在`construct`方法里面使用，用例的网络结构为`Conv2d`->`BiasAdd`。\n",
    "\n",
    "在`construct`方法中，`x`为输入数据，`output`是经过网络结构计算后得到的计算结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "from mindspore import Parameter\n",
    "from mindspore.common.initializer import initializer\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self, in_channels=10, out_channels=20, kernel_size=3):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv2d = ops.Conv2D(out_channels, kernel_size)\n",
    "        self.bias_add = ops.BiasAdd()\n",
    "        self.weight = Parameter(\n",
    "            initializer('normal', [out_channels, in_channels, kernel_size, kernel_size]),\n",
    "            name='conv.weight')\n",
    "\n",
    "    def construct(self, x):\n",
    "        output = self.conv2d(x, self.weight)\n",
    "        output = self.bias_add(output, self.bias)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### parameters_dict\n",
    "`parameters_dict`方法识别出网络结构中所有的参数，返回一个以key为参数名，value为参数值的`OrderedDict`。\n",
    "\n",
    "`Cell`类中返回参数的方法还有许多，例如`get_parameters`、`trainable_params`等，具体使用方法可以参见[API文档](https://www.mindspore.cn/doc/api_python/zh-CN/master/mindspore/nn/mindspore.nn.Cell.html)。\n",
    "\n",
    "代码样例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "odict_keys(['conv.weight'])\n",
      "Parameter (name=conv.weight, value=[[[[-3.79058602e-03 -7.49116531e-03  6.54395670e-03]\n",
      "   [ 2.95572029e-03  2.35575647e-03 -1.18450755e-02]\n",
      "   [-8.93178117e-03  4.29675775e-03 -1.51903387e-02]]\n",
      "\n",
      "  [[-8.24289862e-03  1.45210586e-02 -7.69860717e-03]\n",
      "   [-6.99218456e-03  3.68638476e-03 -6.29653595e-03]\n",
      "   [ 3.95354349e-03 -8.72690696e-03 -3.47765832e-04]]\n",
      "\n",
      "  [[ 2.51406804e-04  1.54294372e-02 -7.59741105e-03]\n",
      "   [ 2.09463271e-03  5.88058168e-03 -1.34559879e-02]\n",
      "   [ 1.62827142e-03 -5.09303994e-03 -1.03666903e-02]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[ 1.16897747e-02 -5.39411651e-03  1.76131574e-03]\n",
      "   [-9.93336271e-03 -8.59452412e-03  1.45569425e-02]\n",
      "   [-2.64248601e-03  1.31645892e-02 -7.67136551e-03]]\n",
      "\n",
      "  [[ 6.72375411e-03 -1.37590384e-02 -1.58686785e-03]\n",
      "   [ 2.21157447e-03 -8.83827638e-03 -6.06854726e-03]\n",
      "   [ 2.88759172e-03  4.16931603e-03  2.94104894e-03]]\n",
      "\n",
      "  [[-1.19857751e-02  6.25934359e-03 -1.07316561e-02]\n",
      "   [ 1.64881181e-02  6.71197521e-03 -2.02514580e-03]\n",
      "   [ 4.66409093e-03  3.84772522e-03 -6.84277294e-03]]]\n",
      "\n",
      "\n",
      " [[[-2.30321679e-02  1.37405575e-03  7.78581668e-03]\n",
      "   [ 9.03773773e-03  8.85955710e-03 -2.75525171e-03]\n",
      "   [ 5.66808798e-04  1.16566438e-02  6.69721002e-03]]\n",
      "\n",
      "  [[ 1.33153396e-02 -1.46484599e-02  1.41494293e-02]\n",
      "   [ 5.70760155e-03 -8.79157893e-03  1.77908351e-03]\n",
      "   [-1.48991437e-03 -8.95473361e-03 -2.80581210e-02]]\n",
      "\n",
      "  [[-2.99305248e-04 -6.61530113e-03  1.77572258e-02]\n",
      "   [-6.06338494e-03  2.12488417e-03 -2.07193266e-03]\n",
      "   [-2.74577906e-04  1.34050875e-04  4.34431154e-03]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[-4.24741162e-03 -2.46842522e-02 -9.57083143e-03]\n",
      "   [-9.77984443e-03 -1.50003182e-02 -1.03083439e-02]\n",
      "   [-1.96074361e-05  1.71559851e-03  2.20309221e-03]]\n",
      "\n",
      "  [[ 1.17843016e-03  5.20134857e-03 -1.58897135e-02]\n",
      "   [ 1.17921419e-02  1.38110323e-02 -4.29972541e-04]\n",
      "   [ 1.70474686e-02 -1.27913784e-02  1.61680888e-04]]\n",
      "\n",
      "  [[ 1.05532361e-02 -1.90200715e-03  1.97972544e-03]\n",
      "   [ 2.55954615e-03 -7.95821752e-03 -9.22696106e-03]\n",
      "   [ 4.60846350e-03 -2.81826360e-04  8.02740548e-03]]]\n",
      "\n",
      "\n",
      " [[[-7.23145131e-05  5.42994682e-03 -1.58438943e-02]\n",
      "   [ 5.99237718e-03 -1.55890090e-02  3.43591929e-03]\n",
      "   [-1.91357806e-02  6.33356208e-03 -1.24399252e-02]]\n",
      "\n",
      "  [[-7.35368056e-04 -4.21275850e-03  5.67453075e-03]\n",
      "   [-1.60260107e-02 -5.28890407e-03 -3.44187859e-03]\n",
      "   [ 3.84379062e-03 -3.95327667e-03  3.84928752e-03]]\n",
      "\n",
      "  [[-9.22779087e-03 -1.16742654e-02 -1.78000852e-02]\n",
      "   [ 8.70509911e-03  8.91428033e-04  1.39736349e-03]\n",
      "   [ 2.62823584e-03 -7.26819690e-03  8.50942172e-03]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[ 5.03179757e-03 -1.53136225e-02 -1.64501485e-03]\n",
      "   [-1.38264569e-02 -9.66335926e-03 -7.14887958e-03]\n",
      "   [ 7.47693982e-03 -2.15925393e-03  2.28523044e-03]]\n",
      "\n",
      "  [[-7.27958838e-03  5.90789504e-03 -1.05854226e-02]\n",
      "   [-9.39854514e-03 -2.83592567e-03  1.88150327e-03]\n",
      "   [ 9.45885293e-03  5.82330208e-03 -7.18125782e-04]]\n",
      "\n",
      "  [[-3.98916230e-02  1.34961482e-03 -5.71780885e-03]\n",
      "   [-1.22198826e-02 -6.02510467e-04 -6.99814921e-03]\n",
      "   [ 8.21080804e-03  2.55070766e-03 -1.42307254e-03]]]\n",
      "\n",
      "\n",
      " ...\n",
      "\n",
      "\n",
      " [[[-1.18070683e-02  2.44325446e-03 -5.89270936e-03]\n",
      "   [ 6.21154904e-04 -9.12335236e-03  1.36900656e-02]\n",
      "   [-6.07193820e-03 -8.44419561e-03  8.41168128e-03]]\n",
      "\n",
      "  [[-1.43579794e-02  4.77887969e-03  2.16902955e-03]\n",
      "   [-7.26557663e-03  7.00268615e-03  1.37364501e-02]\n",
      "   [-2.74261786e-03  1.81445992e-03  3.85763566e-03]]\n",
      "\n",
      "  [[ 1.62989385e-02  8.25487915e-03  1.39235612e-02]\n",
      "   [-2.63211294e-03  6.86650863e-03 -5.32033388e-03]\n",
      "   [ 2.64116284e-03  2.78708292e-03  1.78428963e-02]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[ 1.02946069e-02  9.44312662e-03  5.82270976e-03]\n",
      "   [-1.89918187e-02 -8.64485651e-03  4.39302734e-04]\n",
      "   [-4.84286249e-03  9.51220747e-04  1.60740817e-03]]\n",
      "\n",
      "  [[-8.78979266e-03 -1.28001496e-02  4.36241785e-03]\n",
      "   [ 2.24327948e-03  4.94629331e-03  9.89654008e-03]\n",
      "   [-9.22667910e-04 -3.26456456e-03 -5.66262193e-03]]\n",
      "\n",
      "  [[-2.32832320e-03 -4.71439725e-03 -1.11497473e-02]\n",
      "   [-1.02671739e-02 -6.24760240e-03 -1.33280521e-02]\n",
      "   [ 2.82080821e-03 -2.04679416e-03 -8.69380031e-03]]]\n",
      "\n",
      "\n",
      " [[[ 1.34886345e-02 -3.59563460e-03  2.12849583e-02]\n",
      "   [-2.90595368e-03  2.10898672e-03  3.86864692e-03]\n",
      "   [ 9.81558673e-03 -1.56686474e-02 -4.62771254e-03]]\n",
      "\n",
      "  [[ 2.14769179e-03  2.08799876e-02 -3.52404662e-04]\n",
      "   [ 5.57923131e-03 -2.61301436e-02  2.86811846e-03]\n",
      "   [ 1.73691977e-02  8.55967216e-03  6.73174858e-03]]\n",
      "\n",
      "  [[ 3.81736667e-04  3.17116547e-03 -2.09367397e-04]\n",
      "   [-1.82383857e-03  5.18669141e-04  1.36037858e-03]\n",
      "   [ 1.26522558e-03 -9.87494178e-03 -2.81262398e-03]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[-6.00081123e-03  3.29898112e-03  9.60681355e-04]\n",
      "   [ 6.54833624e-04  6.40623132e-03  9.86214075e-03]\n",
      "   [-1.15496363e-03  1.82333821e-03  8.21203180e-03]]\n",
      "\n",
      "  [[-4.22242563e-03 -1.45574019e-03  1.65237219e-03]\n",
      "   [ 2.25831345e-02  1.29176928e-02 -8.02992750e-03]\n",
      "   [-1.28453244e-02  1.04181040e-02  1.31541276e-02]]\n",
      "\n",
      "  [[ 2.64914753e-03 -5.67125762e-03  5.63753722e-03]\n",
      "   [-5.36110671e-03 -1.70621462e-02 -1.58901960e-02]\n",
      "   [-4.35212307e-04 -2.71965452e-02  8.70116334e-03]]]\n",
      "\n",
      "\n",
      " [[[ 1.78805273e-02  1.34418588e-02 -1.11730220e-02]\n",
      "   [ 7.79538928e-03 -3.09715117e-03 -1.41505226e-02]\n",
      "   [ 1.18346307e-02 -6.39158720e-03 -2.49270606e-03]]\n",
      "\n",
      "  [[-1.13375848e-02  1.25055865e-03  1.15597649e-02]\n",
      "   [ 1.61244385e-02 -5.76407555e-03  1.13344463e-02]\n",
      "   [ 2.75934278e-03  2.99722832e-02 -1.79623044e-03]]\n",
      "\n",
      "  [[ 1.37456162e-02 -1.54584693e-02 -5.62199857e-03]\n",
      "   [-1.06104644e-05 -6.02990622e-04  1.15524372e-02]\n",
      "   [ 1.15251923e-02  1.74393572e-04  1.89418159e-02]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[ 5.94240660e-03 -9.01879929e-03  1.97018851e-02]\n",
      "   [ 1.05262119e-02 -3.07446276e-03 -8.22904427e-03]\n",
      "   [-9.04161111e-03 -1.01207430e-02  2.76559056e-03]]\n",
      "\n",
      "  [[-7.84084387e-03  2.30750605e-03 -9.32552479e-03]\n",
      "   [ 3.35139083e-03 -4.04170249e-03 -9.69938654e-03]\n",
      "   [-1.09762466e-02 -8.33445974e-03 -7.32679293e-03]]\n",
      "\n",
      "  [[-6.28533401e-03 -1.29033374e-02  1.17999716e-02]\n",
      "   [ 2.92376219e-03  1.01038693e-02 -5.01759350e-03]\n",
      "   [-2.73497240e-03  2.06234190e-03  9.91478097e-03]]]])\n"
     ]
    }
   ],
   "source": [
    "net = Net()\n",
    "result = net.parameters_dict()\n",
    "print(result.keys())\n",
    "print(result['conv.weight'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "样例中的`Net`采用上文构造网络的用例，打印了网络中所有参数的名字和`conv.weight`参数的结果。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### cells_and_names\n",
    "`cells_and_names`方法是一个迭代器，返回网络中每个`Cell`的名字和它的内容本身。\n",
    "\n",
    "用例简单实现了获取与打印每个`Cell`名字的功能，其中根据网络结构可知，存在1个`Cell`为`nn.Conv2d`。\n",
    "\n",
    "其中`nn.Conv2d`是`MindSpore`以Cell为基类封装好的一个卷积层，其具体内容将在“模型层”中进行介绍。\n",
    "\n",
    "代码样例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('', Net1<\n",
      "  (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3),stride=(1, 1),  pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False,weight_init=normal, bias_init=zeros>\n",
      "  >)\n",
      "('conv', Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3),stride=(1, 1),  pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False,weight_init=normal, bias_init=zeros>)\n",
      "-------names-------\n",
      "['conv']\n"
     ]
    }
   ],
   "source": [
    "import mindspore.nn as nn\n",
    "\n",
    "class Net1(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net1, self).__init__()\n",
    "        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')\n",
    "\n",
    "    def construct(self, x):\n",
    "        out = self.conv(x)\n",
    "        return out\n",
    "\n",
    "net = Net1()\n",
    "names = []\n",
    "for m in net.cells_and_names():\n",
    "    print(m)\n",
    "    names.append(m[0]) if m[0] else None\n",
    "print('-------names-------')\n",
    "print(names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### set_grad\n",
    "`set_grad`接口功能是使用户构建反向网络，在不传入参数调用时，默认设置`requires_grad`为True，需要在计算网络反向的场景中使用。\n",
    "\n",
    "以`TrainOneStepCell`为例，其接口功能是使网络进行单步训练，需要计算网络反向，因此初始化方法里需要使用`set_grad`。\n",
    "\n",
    "`TrainOneStepCell`部分代码如下："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "class TrainOneStepCell(Cell):\n",
    "    def __init__(self, network, optimizer, sens=1.0):\n",
    "        super(TrainOneStepCell, self).__init__(auto_prefix=False)\n",
    "        self.network = network\n",
    "        self.network.set_grad()\n",
    "        ......\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果用户使用`TrainOneStepCell`等类似接口无需使用`set_grad`， 内部已封装实现。\n",
    "\n",
    "若用户需要自定义此类训练功能的接口，需要在其内部调用，或者在外部设置`network.set_grad`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## nn模块与ops模块的关系\n",
    "MindSpore的nn模块是Python实现的模型组件，是对低阶API的封装，主要包括各种模型层、损失函数、优化器等。\n",
    "\n",
    "同时nn也提供了部分与`Primitive`算子同名的接口，主要作用是对`Primitive`算子进行进一步封装，为用户提供更友好的API。\n",
    "\n",
    "重新分析上文介绍`construct`方法的用例，此用例是MindSpore的`nn.Conv2d`源码简化内容，内部会调用`ops.Conv2D`。`nn.Conv2d`卷积API增加输入参数校验功能并判断是否`bias`等，是一个高级封装的模型层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "from mindspore import Parameter\n",
    "from mindspore.common.initializer import initializer\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self, in_channels=10, out_channels=20, kernel_size=3):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv2d = ops.Conv2D(out_channels, kernel_size)\n",
    "        self.bias_add = ops.BiasAdd()\n",
    "        self.weight = Parameter(\n",
    "            initializer('normal', [out_channels, in_channels, kernel_size, kernel_size]),\n",
    "            name='conv.weight')\n",
    "\n",
    "    def construct(self, x):\n",
    "        output = self.conv2d(x, self.weight)\n",
    "        output = self.bias_add(output, self.bias)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型层\n",
    "在讲述了`Cell`的使用方法后可知，MindSpore能够以`Cell`为基类构造网络结构。\n",
    "\n",
    "为了方便用户的使用，MindSpore框架内置了大量的模型层，用户可以通过接口直接调用。\n",
    "\n",
    "同样，用户也可以自定义模型，此内容在“构建自定义网络”中介绍。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 内置模型层\n",
    "MindSpore框架在`mindspore.nn`的layer层内置了丰富的接口，主要内容如下：\n",
    "\n",
    "* 激活层\n",
    "\n",
    "    激活层内置了大量的激活函数，在定义网络结构中经常使用。激活函数为网络加入了非线性运算，使得网络能够拟合效果更好。\n",
    "\n",
    "    主要接口有`Softmax`、`Relu`、`Elu`、`Tanh`、`Sigmoid`等。\n",
    "    \n",
    "\n",
    "* 基础层\n",
    "\n",
    "    基础层实现了网络中一些常用的基础结构，例如全连接层、Onehot编码、Dropout、平铺层等都在此部分实现。\n",
    "\n",
    "    主要接口有`Dense`、`Flatten`、`Dropout`、`Norm`、`OneHot`等。\n",
    "    \n",
    "\n",
    "* 容器层\n",
    "\n",
    "    容器层主要功能是实现一些存储多个Cell的数据结构。\n",
    "\n",
    "    主要接口有`SequentialCell`、`CellList`等。\n",
    "    \n",
    "\n",
    "* 卷积层\n",
    "\n",
    "    卷积层提供了一些卷积计算的功能，如普通卷积、深度卷积和卷积转置等。\n",
    "\n",
    "    主要接口有`Conv2d`、`Conv1d`、`Conv2dTranspose`、`Conv1dTranspose`等。\n",
    "    \n",
    "\n",
    "* 池化层\n",
    "\n",
    "    池化层提供了平均池化和最大池化等计算的功能。\n",
    "\n",
    "    主要接口有`AvgPool2d`、`MaxPool2d`和`AvgPool1d`。\n",
    "    \n",
    "\n",
    "* 嵌入层\n",
    "\n",
    "    嵌入层提供word embedding的计算功能，将输入的单词映射为稠密向量。\n",
    "\n",
    "    主要接口有`Embedding`、`EmbeddingLookup`、`EmbeddingLookUpSplitMode`等。\n",
    "    \n",
    "\n",
    "* 长短记忆循环层\n",
    "\n",
    "    长短记忆循环层提供LSTM计算功能。其中`LSTM`内部会调用`LSTMCell`接口，`LSTMCell`是一个LSTM单元，对一个LSTM层做运算，当涉及多LSTM网络层运算时，使用`LSTM`接口。\n",
    "\n",
    "    主要接口有`LSTM`和`LSTMCell`。\n",
    "    \n",
    "\n",
    "* 标准化层\n",
    "\n",
    "    标准化层提供了一些标准化的方法，即通过线性变换等方式将数据转换成均值和标准差。\n",
    "\n",
    "    主要接口有`BatchNorm1d`、`BatchNorm2d`、`LayerNorm`、`GroupNorm`、`GlobalBatchNorm`等。\n",
    "    \n",
    "\n",
    "* 数学计算层\n",
    "\n",
    "    数学计算层提供一些算子拼接而成的计算功能，例如数据生成和一些数学计算等。\n",
    "\n",
    "    主要接口有`ReduceLogSumExp`、`Range`、`LinSpace`、`LGamma`等。\n",
    "    \n",
    "\n",
    "* 图片层\n",
    "\n",
    "    图片计算层提供了一些矩阵计算相关的功能，将图片数据进行一些变换与计算。\n",
    "\n",
    "    主要接口有`ImageGradients`、`SSIM`、`MSSSIM`、`PSNR`、`CentralCrop`等。\n",
    "    \n",
    "\n",
    "* 量化层\n",
    "\n",
    "    量化是指将数据从float的形式转换成一段数据范围的int类型，所以量化层提供了一些数据量化的方法和模型层结构封装。\n",
    "\n",
    "    主要接口有`Conv2dBnAct`、`DenseBnAct`、`Conv2dBnFoldQuant`、`LeakyReLUQuant`等。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 应用实例\n",
    "MindSpore的模型层在`mindspore.nn`下，使用方法如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')\n",
    "        self.bn = nn.BatchNorm2d(64)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.fc = nn.Dense(64 * 222 * 222, 3)\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = self.bn(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.flatten(x)\n",
    "        out = self.fc(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "依然是上述网络构造的用例，从这个用例中可以看出，程序调用了`Conv2d`、`BatchNorm2d`、`ReLU`、`Flatten`和`Dense`模型层的接口。\n",
    "\n",
    "在`Net`初始化方法里被定义，然后在`construct`方法里真正运行，这些模型层接口有序的连接，形成一个可执行的网络。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 损失函数\n",
    "目前MindSpore主要支持的损失函数有`L1Loss`、`MSELoss`、`SmoothL1Loss`、`SoftmaxCrossEntropyWithLogits`、`SampledSoftmaxLoss`、`BCELoss`和`CosineEmbeddingLoss`。\n",
    "\n",
    "MindSpore的损失函数全部是`Cell`的子类实现，所以也支持用户自定义损失函数，其构造方法在“构建自定义网络”中进行介绍。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 内置损失函数\n",
    "* L1Loss\n",
    "\n",
    "    计算两个输入数据的绝对值误差，用于回归模型。`reduction`参数默认值为mean，返回loss平均值结果，若`reduction`值为sum，返回loss累加结果，若`reduction`值为none，返回每个loss的结果。\n",
    "    \n",
    "\n",
    "* MSELoss\n",
    "\n",
    "    计算两个输入数据的平方误差，用于回归模型。`reduction`参数同`L1Loss`。\n",
    "    \n",
    "\n",
    "* SmoothL1Loss\n",
    "\n",
    "    `SmoothL1Loss`为平滑L1损失函数，用于回归模型，阈值`beta`默认参数为1。\n",
    "    \n",
    "\n",
    "* SoftmaxCrossEntropyWithLogits\n",
    "\n",
    "    交叉熵损失函数，用于分类模型。当标签数据不是one-hot编码形式时，需要输入参数`sparse`为True。`reduction`参数默认值为none，其参数含义同`L1Loss`。\n",
    "    \n",
    "\n",
    "* CosineEmbeddingLoss\n",
    "\n",
    "    `CosineEmbeddingLoss`用于衡量两个输入相似程度，用于分类模型。`margin`默认为0.0，`reduction`参数同`L1Loss`。\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 应用实例\n",
    "MindSpore的损失函数全部在mindspore.nn下，使用方法如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.5\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor\n",
    "\n",
    "loss = nn.L1Loss()\n",
    "input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))\n",
    "target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))\n",
    "print(loss(input_data, target_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此用例构造了两个Tensor数据，利用`nn.L1Loss`接口定义了loss，将`input_data`和`target_data`传入loss，执行L1Loss的计算，结果为1.5。若loss = nn.L1Loss(reduction=’sum’)，则结果为9.0。若loss = nn.L1Loss(reduction=’none’)，结果为[[1. 0. 2.] [1. 2. 3.]]。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 优化算法\n",
    "`mindspore.nn.optim`是MindSpore框架中实现各种优化算法的模块，详细说明参见[优化算法](https://www.mindspore.cn/doc/programming_guide/zh-CN/master/optim.html)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建自定义网络\n",
    "无论是网络结构，还是前文提到的模型层、损失函数和优化器等，本质上都是一个`Cell`，因此都可以自定义实现。\n",
    "\n",
    "首先构造一个继承`Cell`的子类，然后在`__init__`方法里面定义算子和模型层等，在`construct`方法里面构造网络结构。\n",
    "\n",
    "以LeNet网络为例，在`__init__`方法中定义了卷积层，池化层和全连接层等结构单元，然后在`construct`方法将定义的内容连接在一起，形成一个完整LeNet的网络结构。\n",
    "\n",
    "LeNet网络实现方式如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "\n",
    "class LeNet5(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode=\"valid\")\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode=\"valid\")\n",
    "        self.fc1 = nn.Dense(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Dense(120, 84)\n",
    "        self.fc3 = nn.Dense(84, 3)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.max_pool2d(self.relu(self.conv1(x)))\n",
    "        x = self.max_pool2d(self.relu(self.conv2(x)))\n",
    "        x = self.flatten(x)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore-1.0.1",
   "language": "python",
   "name": "mindspore-1.0.1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
